// Copyright (C) Kumo inc. and its affiliates.
// Author: Jeff.li lijippy@163.com
// All rights reserved.
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published
// by the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program.  If not, see <https://www.gnu.org/licenses/>.
//

#pragma once

#include <cstdint>
#include <memory>
#include <string>
#include <vector>

#include <nebula/core/buffer.h>
#include <nebula/core/compare.h>

#include <turbo/utility/status.h>
#include <nebula/types/type.h>
#include <nebula/types/type_traits.h>
#include <turbo/base/macros.h>

namespace nebula {

    /// \brief Return true if values of the given type id can be stored in a Tensor.
    ///
    /// Tensors only hold fixed-width numeric values: the 8/16/32/64-bit signed
    /// and unsigned integers and the FP16/FP32/FP64 floating point types.
    /// Every other type id is rejected.
    ///
    /// constexpr so the check can also be performed at compile time.
    ///
    /// \param[in] type_id the type id to test
    /// \return true when a Tensor can hold values of this type
    static constexpr bool is_tensor_supported(Type::type type_id) {
        switch (type_id) {
            case Type::UINT8:
            case Type::INT8:
            case Type::UINT16:
            case Type::INT16:
            case Type::UINT32:
            case Type::INT32:
            case Type::UINT64:
            case Type::INT64:
            case Type::FP16:
            case Type::FP32:
            case Type::FP64:
                return true;
            default:
                return false;
        }
    }

    namespace internal {

        /// \brief Compute the strides for a contiguous row-major ("C order")
        /// layout of the given fixed-width type and shape, writing the result
        /// into *strides.
        /// NOTE(review): strides appear to be expressed in bytes (Tensor adds
        /// them directly to a uint8_t* in value()) — confirm in the .cc file.
        TURBO_EXPORT
        turbo::Status ComputeRowMajorStrides(const FixedWidthType &type,
                                             const std::vector<int64_t> &shape,
                                             std::vector<int64_t> *strides);

        /// \brief Compute the strides for a contiguous column-major
        /// ("Fortran order") layout of the given fixed-width type and shape,
        /// writing the result into *strides.
        TURBO_EXPORT
        turbo::Status ComputeColumnMajorStrides(const FixedWidthType &type,
                                                const std::vector<int64_t> &shape,
                                                std::vector<int64_t> *strides);

        /// \brief Return true if the given strides describe a contiguous
        /// (row-major or column-major) layout for the given type and shape.
        TURBO_EXPORT
        bool IsTensorStridesContiguous(const std::shared_ptr<DataType> &type,
                                       const std::vector<int64_t> &shape,
                                       const std::vector<int64_t> &strides);

        /// \brief Check that type/data/shape/strides/dim_names are mutually
        /// consistent before a Tensor is constructed; used by the Tensor and
        /// NumericTensor factory functions and by Tensor::validate().
        TURBO_EXPORT
        turbo::Status ValidateTensorParameters(const std::shared_ptr<DataType> &type,
                                               const std::shared_ptr<Buffer> &data,
                                               const std::vector<int64_t> &shape,
                                               const std::vector<int64_t> &strides,
                                               const std::vector<std::string> &dim_names);

        /// \brief Convert a RecordBatch into a 2-D Tensor allocated from pool.
        /// \param[in] null_to_nan when true, null cells are presumably mapped
        ///            to NaN — TODO confirm against the implementation
        /// \param[in] row_major select row-major (true) or column-major output
        /// \param[out] tensor receives the newly created Tensor
        TURBO_EXPORT
        turbo::Status RecordBatchToTensor(const RecordBatch &batch, bool null_to_nan, bool row_major,
                                          MemoryPool *pool, std::shared_ptr<Tensor> *tensor);

    }  // namespace internal

    /// \brief An n-dimensional array of fixed-width numeric values backed by a
    /// single Buffer, with explicit shape and (byte) strides.
    class TURBO_EXPORT Tensor {
    public:
        /// \brief create a Tensor with full parameters
        ///
        /// This factory function will return turbo::invalid_argument_error when the parameters are
        /// inconsistent
        ///
        /// \param[in] type The data type of the tensor values
        /// \param[in] data The buffer of the tensor content
        /// \param[in] shape The shape of the tensor
        /// \param[in] strides The strides of the tensor
        ///            (if this is empty, the data assumed to be row-major)
        /// \param[in] dim_names The names of the tensor dimensions
        static inline turbo::Result<std::shared_ptr<Tensor>> create(
                const std::shared_ptr<DataType> &type, const std::shared_ptr<Buffer> &data,
                const std::vector<int64_t> &shape, const std::vector<int64_t> &strides = {},
                const std::vector<std::string> &dim_names = {}) {
            // Validate first so an inconsistent Tensor can never be observed by callers.
            TURBO_RETURN_NOT_OK(
                    internal::ValidateTensorParameters(type, data, shape, strides, dim_names));
            return std::make_shared<Tensor>(type, data, shape, strides, dim_names);
        }

        virtual ~Tensor() = default;

        /// Constructor with no dimension names or strides, data assumed to be row-major
        Tensor(const std::shared_ptr<DataType> &type, const std::shared_ptr<Buffer> &data,
               const std::vector<int64_t> &shape);

        /// Constructor with non-negative strides
        Tensor(const std::shared_ptr<DataType> &type, const std::shared_ptr<Buffer> &data,
               const std::vector<int64_t> &shape, const std::vector<int64_t> &strides);

        /// Constructor with non-negative strides and dimension names
        Tensor(const std::shared_ptr<DataType> &type, const std::shared_ptr<Buffer> &data,
               const std::vector<int64_t> &shape, const std::vector<int64_t> &strides,
               const std::vector<std::string> &dim_names);

        /// The data type of the tensor values
        std::shared_ptr<DataType> type() const { return type_; }

        /// The buffer holding the tensor content
        std::shared_ptr<Buffer> data() const { return data_; }

        /// Raw pointer to the (immutable) tensor content
        const uint8_t *raw_data() const { return data_->data(); }

        /// Raw pointer to the mutable tensor content; only meaningful when
        /// is_mutable() returns true
        uint8_t *raw_mutable_data() { return data_->mutable_data(); }

        /// The extent of each dimension
        const std::vector<int64_t> &shape() const { return shape_; }

        /// The stride of each dimension, in bytes (see value())
        const std::vector<int64_t> &strides() const { return strides_; }

        /// The number of dimensions
        int ndim() const { return static_cast<int>(shape_.size()); }

        /// The (optional) names of the tensor dimensions; may be empty
        const std::vector<std::string> &dim_names() const { return dim_names_; }

        /// The name of dimension i
        const std::string &dim_name(int i) const;

        /// Total number of value cells in the tensor
        int64_t size() const;

        /// Return true if the underlying data buffer is mutable
        bool is_mutable() const { return data_->is_mutable(); }

        /// Either row major or column major
        bool is_contiguous() const;

        /// AKA "C order"
        bool is_row_major() const;

        /// AKA "Fortran order"
        bool is_column_major() const;

        /// The type id of the tensor value type
        Type::type type_id() const;

        /// Compare for equality with another tensor
        bool equals(const Tensor &other, const EqualOptions & = EqualOptions::defaults()) const;

        /// Compute the number of non-zero values in the tensor
        turbo::Result<int64_t> CountNonZero() const;

        /// Return the offset of the given index on the given strides.
        /// No bounds checking: index must have no more entries than strides.
        static int64_t CalculateValueOffset(const std::vector<int64_t> &strides,
                                            const std::vector<int64_t> &index) {
            const int64_t n = static_cast<int64_t>(index.size());
            int64_t offset = 0;
            for (int64_t i = 0; i < n; ++i) {
                offset += index[i] * strides[i];
            }
            return offset;
        }

        /// Return the byte offset of the given index using this tensor's strides
        int64_t CalculateValueOffset(const std::vector<int64_t> &index) const {
            return Tensor::CalculateValueOffset(strides_, index);
        }

        /// Returns the value at the given index without data-type and bounds checks.
        /// The caller is responsible for ensuring VT matches type() and that
        /// index is within shape().
        template<typename VT>
        const typename VT::c_type &value(const std::vector<int64_t> &index) const {
            using c_type = typename VT::c_type;
            // Strides are byte offsets: the offset is applied to the raw byte
            // pointer before reinterpreting as the value type.
            const int64_t offset = CalculateValueOffset(index);
            const c_type *ptr = reinterpret_cast<const c_type *>(raw_data() + offset);
            return *ptr;
        }

        /// Re-run the parameter consistency checks on this tensor's current state
        turbo::Status validate() const {
            return internal::ValidateTensorParameters(type_, data_, shape_, strides_, dim_names_);
        }

    protected:
        // For subclasses (e.g. SparseTensorImpl) that fill the members themselves.
        Tensor() = default;

        std::shared_ptr<DataType> type_;
        std::shared_ptr<Buffer> data_;
        std::vector<int64_t> shape_;
        std::vector<int64_t> strides_;

        /// These names are optional
        std::vector<std::string> dim_names_;

        template<typename SparseIndexType>
        friend
        class SparseTensorImpl;

    private:
        TURBO_DISALLOW_COPY_AND_ASSIGN(Tensor);
    };

    /// \brief A Tensor whose value type is fixed at compile time to TYPE.
    ///
    /// The element type is taken from TypeTraits<TYPE>::type_singleton(), so
    /// the constructors only need the buffer, shape, strides and names.
    template<typename TYPE>
    class NumericTensor : public Tensor {
    public:
        using TypeClass = TYPE;
        using value_type = typename TypeClass::c_type;

        /// \brief create a NumericTensor with full parameters
        ///
        /// This factory function will return turbo::invalid_argument_error when the parameters are
        /// inconsistent
        ///
        /// \param[in] data The buffer of the tensor content
        /// \param[in] shape The shape of the tensor
        /// \param[in] strides The strides of the tensor
        ///            (if this is empty, the data assumed to be row-major)
        /// \param[in] dim_names The names of the tensor dimensions
        static turbo::Result<std::shared_ptr<NumericTensor<TYPE>>> create(
                const std::shared_ptr<Buffer> &data, const std::vector<int64_t> &shape,
                const std::vector<int64_t> &strides = {},
                const std::vector<std::string> &dim_names = {}) {
            // Validate against the singleton type for TYPE before constructing.
            const auto singleton = TypeTraits<TYPE>::type_singleton();
            TURBO_RETURN_NOT_OK(internal::ValidateTensorParameters(singleton, data, shape,
                                                                   strides, dim_names));
            return std::make_shared<NumericTensor<TYPE>>(data, shape, strides, dim_names);
        }

        /// Constructor with no dimension names or strides, data assumed to be row-major
        NumericTensor(const std::shared_ptr<Buffer> &data, const std::vector<int64_t> &shape)
                : NumericTensor(data, shape, {}, {}) {}

        /// Constructor with non-negative strides
        NumericTensor(const std::shared_ptr<Buffer> &data, const std::vector<int64_t> &shape,
                      const std::vector<int64_t> &strides)
                : NumericTensor(data, shape, strides, {}) {}

        /// Constructor with non-negative strides and dimension names
        NumericTensor(const std::shared_ptr<Buffer> &data, const std::vector<int64_t> &shape,
                      const std::vector<int64_t> &strides,
                      const std::vector<std::string> &dim_names)
                : Tensor(TypeTraits<TYPE>::type_singleton(), data, shape, strides, dim_names) {}

        /// Typed, unchecked element access at the given index
        const value_type &value(const std::vector<int64_t> &index) const {
            return Tensor::value<TYPE>(index);
        }
    };

}  // namespace nebula
